In [ ]:
import torch as t
from datasets import load_dataset
import huggingface_hub as hf
from pathlib import Path
import einops
import plotly.graph_objects as go
from typing import Union, List, Optional
from jaxtyping import Float
from transformer_lens import ActivationCache
import circuitsvis as cv
from IPython.display import HTML
from plotly.subplots import make_subplots
from othello_gpt.data.vis import plot_in_basis, plot_game
from othello_gpt.util import (
get_all_squares,
load_model,
load_probes,
vocab_to_board,
)
from othello_gpt.data.vis import move_id_to_text
In [ ]:
# Resolve the repository layout relative to this notebook's location.
root_dir = Path.cwd().parent.parent.parent
data_dir = root_dir / "data"
probe_dir = data_dir / "probes"
# hf.login((root_dir / "secret.txt").read_text())
dataset_dict = load_dataset("awonga/othello-gpt")
# Pick the best available accelerator: Apple MPS > CUDA > CPU.
if t.backends.mps.is_available():
    device = t.device("mps")
elif t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")
size = 6  # board side length (6x6 Othello variant)
all_squares = get_all_squares(size)
Resolving data files: 0%| | 0/87 [00:00<?, ?it/s]
Resolving data files: 0%| | 0/87 [00:00<?, ?it/s]
Loading dataset shards: 0%| | 0/87 [00:00<?, ?it/s]
In [ ]:
# Load the pretrained Othello-GPT model and pull out its key dimensions.
model = load_model(device, "awonga/othello-gpt-2M")
n_layer = model.cfg.n_layers
n_head = model.cfg.n_heads
d_head = model.cfg.d_head
d_model = model.cfg.d_model
# Read the MLP width from the config instead of hardcoding the 4x convention;
# TransformerLens defaults d_mlp to 4 * d_model, so this is identical for the
# standard config but stays correct if d_mlp ever differs.
n_neuron = model.cfg.d_mlp
number of parameters: 1.58M
In [ ]:
# Take a small evaluation slice and load the trained linear probes.
n_test = 100
test_dataset = dataset_dict["test"].take(n_test)

# combos=["t+m", "t-m", "t-e", "t-pt", "m-pm"],
probes = load_probes(
    probe_dir,
    device,
    w_u=model.W_U.detach(),
    w_e=model.W_E.T.detach(),
    w_p=model.W_pos.T.detach(),
    combos=["+pee-ee"],
)

# Probe shapes: d_model (row col) n_probe_layer
{name: direction.shape for name, direction in probes.items()}
Out[ ]:
{'ptm': torch.Size([128, 36, 17]),
'tm': torch.Size([128, 36, 17]),
'ee': torch.Size([128, 36, 17]),
'le': torch.Size([128, 36, 17]),
'pee': torch.Size([128, 36, 17]),
'tnpt': torch.Size([128, 36, 17]),
'u': torch.Size([128, 36, 17]),
'b': torch.Size([128, 36, 17]),
'p': torch.Size([128, 31, 17]),
'+pee-ee': torch.Size([128, 36, 17])}
In [ ]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[t.Tensor, "heads"]],
    local_cache: ActivationCache,
    local_tokens: t.Tensor,
    title: str = "",
    max_width: int = 700,
) -> str:
    """Render circuitsvis attention-head plots for the given heads as HTML.

    Args:
        heads: Flat head indices (layer * n_heads + head), given as a single
            int, a list of ints, or a 1-D tensor.
        local_cache: Activation cache from a forward pass; only batch item 0
            is visualised.
        local_tokens: Move ids for the game, used to build axis labels.
        title: Heading placed above the plot.
        max_width: Maximum pixel width of the returned <div>.

    Returns:
        Raw HTML (title + circuitsvis attention plot) as a string.

    NOTE(review): reads the notebook globals `model` (for n_heads) and
    `size` (board side length) rather than taking them as parameters.
    """
    # If a single head is given, convert to a list.
    if isinstance(heads, int):
        heads = [heads]
    labels: List[str] = []
    patterns: List[Float[t.Tensor, "dest_pos src_pos"]] = []
    # Assume we have a single batch item.
    batch_index = 0
    for head in heads:
        # Decode the flat head index into (layer, head) for the label.
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        labels.append(f"L{layer}H{head_index}")
        # Attention patterns have shape [batch, head_index, query_pos, key_pos].
        patterns.append(local_cache["attn", layer][batch_index, head_index])
    # Convert the tokens to strings for the axis labels. Use a distinct loop
    # variable so the comprehension does not shadow the torch alias `t`.
    str_tokens = [move_id_to_text(tok, size) for tok in local_tokens]
    # Combine the per-head patterns into a single [head, dest, src] tensor.
    stacked: Float[t.Tensor, "head_index dest_pos src_pos"] = t.stack(
        patterns, dim=0
    ).cpu()
    # Normalise relative to 1/pos so later (longer) rows don't get diluted.
    stacked *= (t.arange(stacked.shape[1]) + 1).unsqueeze(0).unsqueeze(-1)
    # Circuitsvis plot; show_code() yields raw HTML we can concatenate with the title.
    plot = cv.circuitsvis.attention.attention_heads(
        attention=stacked, tokens=str_tokens, attention_head_names=labels
    ).show_code()
    title_html = f"<h2>{title}</h2><br/>"
    # Return the visualisation as raw code.
    return f"<div style='max-width: {str(max_width)}px;'>{title_html + plot}</div>"
In [ ]:
# Show all heads' attention patterns and the board for the first few test games.
for i in range(3):
    test_game = test_dataset[i]
    test_input_ids = t.tensor(test_game["input_ids"], device=device)
    # Run on all but the last token and cache activations for visualisation.
    test_logits, test_cache = model.run_with_cache(test_input_ids[:-1])
    vis = visualize_attention_patterns(
        # All heads; reuse n_layer/n_head already bound from model.cfg above.
        list(range(n_layer * n_head)),
        test_cache,
        test_game["moves"],
    )
    display(HTML(vis))
    plot_game(test_game)